iT邦幫忙

2022 iThome 鐵人賽

DAY 14
0
AI & Data

JAX 好好玩系列 第 14

JAX 好好玩 (14) : JAX JIT (3) : 函式內陣列的維度

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

我們現在來探討另外一個 JAX JIT 使用上的限制 [14.1]:函式運算的過程中,所有的陣列都必須是靜態的維度 ( it requires all arrays to have static shapes.)。這是什麼意思?

先看一個符合標準的例子:

def norm_tedious(X):
    mean = X.mean(0)
    std = X.std(0)
    y = (X - mean) / std
    return y, mean, std

input = jnp.array([1,2,3,4])
print(f'input shape : {input.shape}')
y, mean, std = norm_tedious(input)
print(f'y shape : {y.shape}')
print(f'mean shape : {mean.shape}')
print(f'std shape : {std.shape}')

output:
*input shape : (4,)
y shape : (4,)
mean shape : ()
std shape : ()

當輸入參數 X 的維度固定之後,這個函式內所有的陣列運算維度都是固定的,這個就是靜態的維度。注意!它並不是要求輸入參數 X 的維度要固定,X 可以是任何合理的維度;它是說當 X 在某一個維度時,函式內的陣列維度都是固定的。

上例中,老頭故意回傳函式 norm_tedious() 內的運算過程中的變數,並印出它們的維度。當輸入參數 input 的維度是 (4,) 時,輸出值 y 的維度必然是 (4,),而中間運算結果 mean 和 std 必然是純量。

下面的例子就不符合標準:

def get_negatives(X):
    y = X[X<0]
    return y
 
get_negatives_jit = jax.jit(get_negatives)

input1 = jnp.array([1,2,3,4])
input2 = jnp.array([-1,2,-3,4])
 
output1 = get_negatives(input1)
output2 = get_negatives(input2)
 
print(f'input1 shape : {input1.shape}')
print(f'input2 shape : {input2.shape}')
print(f'output1 shape : {output1.shape}')
print(f'output2 shape : {output2.shape}')

output:
*input1 shape : (4,)
input2 shape : (4,)
output1 shape : (0,)
output2 shape : (2,)

函式 get_negatives() 回傳值的維度,不僅僅是由輸入參數 X 的維度決定,還要依據輸入參數的內含值。上例,input1 和 input2 是相同維度但含有不同的值,它們所對應的 output1 和 output2 則不同。如此,就違反了 JAX JIT 的要求。

如果硬要將這個不合適的函式用 JAX JIT 執行,執行時將會報錯:

output = get_negatives_jit(input1)

output:
'---------------------------------------------------------------------------'
*UnfilteredStackTrace Traceback (most recent call last)
in
2
----> 3 output = get_negatives_jit(input1)
…..
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[4])

讀者可能會納悶,那麼多的使用限制,我們為什麼還要用 JAX JIT 呢?老頭的看法是效率!它快到我們捨不得放下它。目前 JAX 仍在發展中,截至目前 (2022/09/08) ,它的版本是 0.3.17 (2022/09/01 釋出) [14.2],仍舊是測試版。我們期待未來,JAX 社群能夠發展出更好的工具,讓程式設計師能夠檢查其函式是否符合 JIT 的要求,提出警告及建議,而有效避免程式執行時的意外狀況及錯誤的結果。

[14.1] 本文主要參考了 JAX 文件官網 To JIT or not to JIT 一文的內容。
[14.2] 可由 pypi 的網站查到最新釋出的版本。https://pypi.org/project/jax/


上一篇
JAX 好好玩 (13) : JAX JIT (2) : 純函式 (Pure Function)
下一篇
JAX 好好玩 (15) : JAX JIT (4) : 追踪 (Tracing)
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言